import numpy as np
import math
import torch
import os
import simclr
from LogME import LogME
import torch.nn.functional as F

# Device that inputs/targets are moved to in the feature-extraction helpers.
# Falls back to CPU when CUDA is unavailable instead of failing on first use.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_all_fea(trainloader, model, args, forward_func):
    """Run ``model`` over every batch of ``trainloader`` and gather the outputs.

    Runs under ``torch.no_grad()``. Inputs and targets are moved to the
    module-level ``device`` before ``forward_func`` is called.

    Args:
        trainloader: iterable yielding ``(inputs, target)`` pairs.
        model: passed through unchanged to ``forward_func``.
        args: namespace; only ``args.use_pred`` is read here.
        forward_func: callable ``forward_func(model, inputs)`` returning a
            2-tuple. When ``args.use_pred`` is true, the second element is used
            and softmax-normalised along dim 1. Otherwise the first element is
            used as-is.

    Returns:
        ``(feas, targets)``, each concatenated along dim 0 over all batches,
        or ``(None, None)`` if the loader yields no batches.
    """
    fea_chunks = []
    target_chunks = []

    with torch.no_grad():
        for inputs, target in trainloader:
            inputs, target = inputs.to(device), target.to(device)

            if args.use_pred:
                _, logits = forward_func(model, inputs)
                fea = F.softmax(logits, dim=1)
            else:
                fea, _ = forward_func(model, inputs)

            fea_chunks.append(fea)
            target_chunks.append(target)

    if not fea_chunks:
        return None, None
    # Concatenate once at the end. Growing a tensor with torch.cat on every
    # batch re-copies everything gathered so far (quadratic in dataset size).
    return torch.cat(fea_chunks, 0), torch.cat(target_chunks, 0)

def torch_cov(input_vec: torch.Tensor) -> torch.Tensor:
    """Return the covariance matrix of the rows of ``input_vec``.

    Rows are treated as observations and columns as variables. The result is
    normalised by ``N`` (the number of rows), not ``N - 1``. That makes it the
    biased / maximum-likelihood estimate, not the unbiased sample covariance.

    Args:
        input_vec: 2-D tensor of shape ``(N, D)``.

    Returns:
        ``(D, D)`` covariance tensor.
    """
    centered = input_vec - input_vec.mean(0)
    return centered.T @ centered / centered.size(0)

def calc_db2(sgm, sgm1, sgm2):
    """Multiply ``(s / s1) * (s / s2)`` over the main diagonals of three matrices.

    ``s``, ``s1`` and ``s2`` are the i-th diagonal entries of ``sgm``, ``sgm1``
    and ``sgm2``. Only the diagonals of the (square) inputs are used.

    Returns:
        The Python float ``1.`` when the diagonal is empty, otherwise a
        0-dim tensor holding the product.
    """
    d = torch.diag(sgm)
    d1 = torch.diag(sgm1)
    d2 = torch.diag(sgm2)

    result = 1.
    for idx, val in enumerate(d):
        # Two separate multiplications, in this order, per diagonal entry.
        result *= val / d1[idx]
        result *= val / d2[idx]
    return result

class PCA:
    """Minimal principal component analysis on torch tensors (via SVD).

    ``fit`` centres the data before computing the principal axes. ``transform``
    and ``inverse_transform``, however, project the data exactly as given,
    with no mean subtraction or re-addition.
    """

    def __init__(self, output_dim) -> None:
        # Number of principal axes kept by ``fit``.
        self.output_dim = output_dim

    def fit(self, X_data):
        """Store the top ``output_dim`` principal axes of the rows of ``X_data``.

        The axes become the columns of ``self.base``, shape ``(D, k)`` with
        ``k = min(output_dim, min(N, D))``.
        """
        # Subtracting the column mean equals left-multiplying by the centring
        # matrix H = I - 11^T / N, without materialising an N x N matrix.
        # It also stays on X_data's own device and dtype.
        centered = X_data - X_data.mean(dim=0, keepdim=True)
        _, _, vh = torch.linalg.svd(centered, full_matrices=False)
        # Rows of vh are the right singular vectors; transpose to get columns.
        self.base = vh[: self.output_dim].T

    def fit_transform(self, X_data):
        """Fit on ``X_data`` and return its projection onto the fitted axes."""
        self.fit(X_data)
        return self.transform(X_data)

    def transform(self, X_data):
        """Project rows of ``X_data`` onto ``self.base`` (no centring applied)."""
        return torch.matmul(X_data, self.base)

    def inverse_transform(self, X_data):
        """Map projected rows back to the original space via ``self.base.T``."""
        return torch.matmul(X_data, self.base.T)

def entropy(p, prob=True, mean=True, sum=False):
    """Shannon entropy (natural log) of each distribution along dim 1 of ``p``.

    Args:
        p: tensor with at least 2 dims; the entropy is reduced over dim 1.
        prob: if true, ``p`` is treated as scores and softmax-normalised
            along dim 1 first; otherwise it is used as probabilities directly.
        mean: if true, return the mean of the per-row entropies. This takes
            precedence over ``sum``.
        sum: if true and ``mean`` is false, return the total of the per-row
            entropies.

    Returns:
        A 0-dim tensor when ``mean`` or ``sum`` is set, otherwise the tensor
        of per-row entropies.

    ``1e-5`` is added inside the log so zero probabilities do not give ``-inf``.
    """
    probs = F.softmax(p, dim=1) if prob else p
    per_row = -torch.sum(probs * torch.log(probs + 1e-5), 1)

    if not mean and not sum:
        return per_row
    return torch.mean(per_row) if mean else torch.sum(per_row)

def cls_density(fea, T=0.05):
    """Entropy of the temperature-scaled pairwise class-similarity matrix.

    Each class is summarised by:

    * its mean feature, as a ``(D, 1)`` column;
    * a diagonal covariance, i.e. the diagonal of ``torch_cov`` re-embedded
      as a matrix.

    ``BC`` is evaluated once for every unordered pair of classes. The results
    fill a ``(C, C - 1)`` matrix whose row ``i`` holds class ``i``'s
    similarity to every *other* class (the self-pair column is omitted). The
    return value is ``entropy(similar_mtx / T)``: with ``entropy``'s defaults,
    this is the mean over rows of the entropy of each row's softmax.

    Args:
        fea: sequence of ``C`` per-class feature tensors, presumably each of
            shape ``(N_c, D)``.
        T: temperature dividing the similarities before the softmax.

    Returns:
        0-dim tensor.

    NOTE(review): ``BC`` is not defined or imported in the visible part of this
    module. Presumably it is a Bhattacharyya coefficient between two Gaussians
    ("higher and closer" per the original comment). Confirm it exists.
    """
    mus = []
    sigmas = []
    for cls_fea in fea:
        mus.append(cls_fea.mean(dim=0).unsqueeze(1))
        # Keep only the per-dimension variances (diagonal covariance).
        sigmas.append(torch.diag(torch_cov(cls_fea)).diag_embed())

    num_cls_tgt = len(fea)
    # Allocate next to the features instead of hard-coding CUDA.
    similar_mtx = torch.ones((num_cls_tgt, num_cls_tgt - 1), device=mus[0].device)

    for i in range(num_cls_tgt):
        for j in range(i + 1, num_cls_tgt):
            bc_coef = BC(mus[i], mus[j], sigmas[i], sigmas[j])  # higher => closer
            # Row i omits its own column, so class j > i lands at column j - 1.
            # In row j, class i < j keeps column i.
            similar_mtx[i, j - 1] = bc_coef
            similar_mtx[j, i] = bc_coef

    return entropy(similar_mtx / T)

def _split_by_class(feas, labels, cls_num):
    """Group the rows of ``feas`` by their integer label in ``labels``.

    Returns a list of ``cls_num`` tensors. Entry ``c`` holds the rows whose
    label equals ``c``, in their original order.

    Raises:
        ValueError: if any label lies outside ``[0, cls_num)``, or if a class
            in that range has no samples (its mean/covariance is undefined).
    """
    if labels.numel() and (int(labels.min()) < 0 or int(labels.max()) >= cls_num):
        raise ValueError(
            f"labels must lie in [0, {cls_num}), got range "
            f"[{int(labels.min())}, {int(labels.max())}]"
        )
    groups = [feas[labels == c] for c in range(cls_num)]
    empty = [c for c, g in enumerate(groups) if g.size(0) == 0]
    if empty:
        raise ValueError(f"classes without samples: {empty}")
    return groups


def _nc_score(cls_feas):
    """Return ``trace(Sigma_W @ pinv(Sigma_B)) / C`` as a Python float.

    ``Sigma_W`` is the average within-class covariance. ``Sigma_B`` is the
    covariance of the class means. Both use ``torch_cov``, normalised by N.
    Smaller values mean the classes are tighter relative to the spread of
    their means. This is presumably the NC1 neural-collapse metric.
    """
    cls_num = len(cls_feas)
    sigma_w = sum(torch_cov(f) / cls_num for f in cls_feas)
    sigma_b = torch_cov(torch.stack([f.mean(0) for f in cls_feas]))
    return (sigma_w.matmul(torch.linalg.pinv(sigma_b)).trace() / cls_num).item()


def nc_cls_density_Estimate_cb(trainloader, model, args, forward_func):
    """Score a model's features on a labelled loader.

    Steps:
      1. Gather features for every sample via ``get_all_fea``. When
         ``args.use_pred`` is set, softmax predictions are used instead.
      2. Compute ``_nc_score`` on those raw features, grouped by label.
      3. Project the features to 64 dims with ``PCA`` and L2-normalise each
         row. Then group by label and compute ``cls_density`` with
         temperature ``args.T``. A non-positive density is clamped to ``1e-5``.

    Args:
        trainloader, model, forward_func: forwarded to ``get_all_fea``.
        args: namespace; reads ``use_pred`` (via ``get_all_fea``),
            ``num_cls_tgt`` and ``T``.

    Returns:
        ``(-nc, density)`` as Python floats, where ``nc`` is the value from
        step 2. Both values are printed first, with ``nc`` positive.

    Raises:
        ValueError: if the loader yields no batches, a label is out of range,
            or a class has no samples.
    """
    num_cls = args.num_cls_tgt

    all_fea, all_lb = get_all_fea(trainloader, model, args, forward_func)
    if all_fea is None:
        raise ValueError("trainloader yielded no batches")

    nc_score = _nc_score(_split_by_class(all_fea, all_lb, num_cls))

    reduced = F.normalize(PCA(64).fit_transform(all_fea))
    fea_bank = _split_by_class(reduced, all_lb, num_cls)

    density = cls_density(fea_bank, T=args.T).item()
    if density <= 0:
        density = 1e-5

    print(nc_score, density)
    return -nc_score, density

